import torch
from torch.optim import Optimizer
from torch import nn


class esgd(Optimizer):
    """Momentum SGD driven by a gradient-difference-corrected gradient.

    For every parameter ``p`` with (clipped) gradient ``g_t`` one step does::

        a_t  = beta1 * a_{t-1} + (1 - beta1) * (g_t - g_{t-1})   # from step 2 on
        gm_t = g_t + alp * a_t
        m_t  = beta2 * m_{t-1} + gm_t                            # no dampening
        p   <- p - lr * m_t

    Before the per-parameter update, the gradients of *all* parameters in all
    groups are jointly clipped (in place, so ``p.grad`` is modified) to a
    global L2 norm of ``max_grad_norm``. When ``weight_decay`` is non-zero the
    clipped gradient is replaced by ``g_t + weight_decay * p`` (L2
    regularisation) before the update above; that decayed gradient is also
    what gets remembered as ``g_{t-1}`` for the next step.

    Args:
        params: iterable of parameters or of dicts defining parameter groups.
        lr: learning rate, must be >= 0.
        betas: ``(beta1, beta2)`` -- decay of the gradient-difference EMA and
            the momentum coefficient; each must lie in ``[0, 1)``.
        alp: scale applied to the gradient-difference EMA when forming the
            corrected gradient.
        eps: validated (must be >= 0) and stored in the param groups, but not
            used by the update.
        weight_decay: L2 penalty coefficient, must be >= 0.
        max_grad_norm: maximum global L2 gradient norm used for clipping;
            ``None`` disables clipping. Defaults to 1.0, the value that was
            previously hard-coded.

    Raises:
        ValueError: if a hyperparameter is out of range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.99, 0.9), alp=-5, eps=1e-6,
                 weight_decay=0, max_grad_norm=1.0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay))
        if max_grad_norm is not None and not max_grad_norm > 0.0:
            raise ValueError(
                "Invalid max_grad_norm value: {}".format(max_grad_norm))
        defaults = dict(lr=lr, betas=betas, eps=eps, alp=alp,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)
        # Clipping is computed over every group at once, so the threshold is
        # an optimizer-wide setting rather than a per-group default.
        self.max_grad_norm = max_grad_norm

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It runs with autograd enabled, so it may call
                ``backward()`` itself.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure must build a graph
            # for backward(), so autograd is re-enabled just for this call.
            with torch.enable_grad():
                loss = closure()

        if self.max_grad_norm is not None:
            # One global norm across every parameter of every group.
            torch.nn.utils.clip_grad_norm_(
                parameters=[
                    p for group in self.param_groups for p in group['params']],
                max_norm=self.max_grad_norm,
                norm_type=2,
            )

        for group in self.param_groups:
            # Per-group hyperparameters are loop-invariant for the inner loop.
            beta1, beta2 = group['betas']
            alp = group['alp']
            lr = group['lr']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        'We does not support sparse gradients, consider SparseAdam instad.')
                if weight_decay != 0:
                    # L2 regularisation: gradient of loss + wd/2 * ||p||^2.
                    grad = grad.add(p, alpha=weight_decay)

                state = self.state[p]
                if not state:
                    state['step'] = 0
                    # Momentum buffer m_t.
                    state['exp_avg'] = torch.zeros_like(p)
                    # Gradient used on the previous step, g_{t-1}.
                    state['prev_grad'] = torch.zeros_like(p)
                    # EMA of gradient differences, a_t.
                    state['exp_grad_diff'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']
                exp_grad_diff = state['exp_grad_diff']
                prev_grad = state['prev_grad']

                exp_grad_diff.mul_(beta1)
                # Differences only enter from the third call (counter >= 2),
                # although prev_grad already holds a real gradient at
                # counter 1. NOTE(review): confirm this warm-up is intended.
                if state['step'] >= 2:
                    exp_grad_diff.add_(grad - prev_grad, alpha=1 - beta1)

                # Corrected gradient g_t + alp * a_t (no bias correction).
                modified_grad = grad.add(exp_grad_diff, alpha=alp)
                state['step'] += 1

                # m_t = beta2 * m_{t-1} + corrected gradient; note there is no
                # (1 - beta2) dampening factor.
                exp_avg.mul_(beta2).add_(modified_grad)
                p.add_(exp_avg, alpha=-lr)

                # Reuse the existing buffer instead of allocating a clone.
                prev_grad.copy_(grad)
        return loss

def create_esgd_optimizer(model, lr, betas=(0.9, 0.9), alpha=-5, eps=1e-6,
                          weight_decay=0):
    """Build an ``esgd`` optimizer.

    Args:
        model: either an ``nn.Module`` (its ``parameters()`` are optimized) or
            anything ``esgd`` accepts directly as ``params`` (an iterable of
            tensors or of parameter-group dicts).
        lr: learning rate.
        betas: ``(beta1, beta2)`` forwarded to ``esgd``. The default ``beta1``
            here (0.9) differs from ``esgd``'s own default (0.99).
        alpha: forwarded to ``esgd`` as ``alp``.
        eps: forwarded to ``esgd``.
        weight_decay: forwarded to ``esgd``.

    Returns:
        The constructed ``esgd`` instance.
    """
    # A plain Module is not an iterable of tensors, so handing it straight to
    # the optimizer would fail; use its parameters instead.
    params = model.parameters() if isinstance(model, nn.Module) else model
    return esgd(params, lr, betas=betas, eps=eps, alp=alpha,
                weight_decay=weight_decay)